{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "prefix = 'https://f000.backblazeb2.com/file/malay-dataset/tagging/ontonotes5/'\n",
    "\n",
    "urls = \"\"\"\n",
    "augmentation-address-ontonotes5.json\n",
    "augmentation-event-ontonotes5.json\n",
    "augmentation-fac-ontonotes5.json\n",
    "augmentation-gpe-ontonotes5.json\n",
    "augmentation-language-ontonotes5.json\n",
    "augmentation-law-ontonotes5.json\n",
    "augmentation-loc-ontonotes5.json\n",
    "augmentation-norp-ontonotes5.json\n",
    "augmentation-org-ontonotes5.json\n",
    "augmentation-person-ontonotes5.json\n",
    "augmentation-product-ontonotes5.json\n",
    "augmentation-work-of-art-ontonotes5.json\n",
    "ontonotes5-train-test.json\n",
    "\"\"\"\n",
    "urls = list(filter(None, urls.split('\\n')))\n",
    "\n",
    "import os\n",
    "\n",
    "# uncomment to download\n",
    "# for url in urls:\n",
    "#     print(url)\n",
    "#     os.system(f'wget {prefix}{url}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['train_X', 'train_Y', 'test_X', 'test_Y'])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "with open('ontonotes5-train-test.json') as fopen:\n",
    "    data = json.load(fopen)\n",
    "data.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_X = data['train_X']\n",
    "test_X = data['test_X']\n",
    "train_Y = data['train_Y']\n",
    "test_Y = data['test_Y']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = [\n",
    "    {'Tag': 'OTHER', 'Description': 'other'},\n",
    "    {'Tag': 'ADDRESS', 'Description': 'Address of physical location.'},\n",
    "    {'Tag': 'PERSON', 'Description': 'People, including fictional.'},\n",
    "    {\n",
    "        'Tag': 'NORP',\n",
    "        'Description': 'Nationalities or religious or political groups.',\n",
    "    },\n",
    "    {\n",
    "        'Tag': 'FAC',\n",
    "        'Description': 'Buildings, airports, highways, bridges, etc.',\n",
    "    },\n",
    "    {\n",
    "        'Tag': 'ORG',\n",
    "        'Description': 'Companies, agencies, institutions, etc.',\n",
    "    },\n",
    "    {'Tag': 'GPE', 'Description': 'Countries, cities, states.'},\n",
    "    {\n",
    "        'Tag': 'LOC',\n",
    "        'Description': 'Non-GPE locations, mountain ranges, bodies of water.',\n",
    "    },\n",
    "    {\n",
    "        'Tag': 'PRODUCT',\n",
    "        'Description': 'Objects, vehicles, foods, etc. (Not services.)',\n",
    "    },\n",
    "    {\n",
    "        'Tag': 'EVENT',\n",
    "        'Description': 'Named hurricanes, battles, wars, sports events, etc.',\n",
    "    },\n",
    "    {'Tag': 'WORK_OF_ART', 'Description': 'Titles of books, songs, etc.'},\n",
    "    {'Tag': 'LAW', 'Description': 'Named documents made into laws.'},\n",
    "    {'Tag': 'LANGUAGE', 'Description': 'Any named language.'},\n",
    "    {\n",
    "        'Tag': 'DATE',\n",
    "        'Description': 'Absolute or relative dates or periods.',\n",
    "    },\n",
    "    {'Tag': 'TIME', 'Description': 'Times smaller than a day.'},\n",
    "    {'Tag': 'PERCENT', 'Description': 'Percentage, including \"%\".'},\n",
    "    {'Tag': 'MONEY', 'Description': 'Monetary values, including unit.'},\n",
    "    {\n",
    "        'Tag': 'QUANTITY',\n",
    "        'Description': 'Measurements, as of weight or distance.',\n",
    "    },\n",
    "    {'Tag': 'ORDINAL', 'Description': '\"first\", \"second\", etc.'},\n",
    "    {\n",
    "        'Tag': 'CARDINAL',\n",
    "        'Description': 'Numerals that do not fall under another type.',\n",
    "    },\n",
    "]\n",
    "d = [d['Tag'] for d in d]\n",
    "d = ['PAD', 'X'] + d\n",
    "tag2idx = {i: no for no, i in enumerate(d)}\n",
    "idx2tag = {no: i for no, i in enumerate(d)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "augmentation-org-ontonotes5.json\n",
      "35150 35150\n",
      "6300 6300\n",
      "augmentation-fac-ontonotes5.json\n",
      "17040 17040\n",
      "4260 4260\n",
      "augmentation-loc-ontonotes5.json\n",
      "13040 13040\n",
      "2460 2460\n",
      "augmentation-gpe-ontonotes5.json\n",
      "21060 21060\n",
      "0 0\n",
      "augmentation-work-of-art-ontonotes5.json\n",
      "4020 4020\n",
      "1020 1020\n",
      "augmentation-event-ontonotes5.json\n",
      "1665 1665\n",
      "0 0\n",
      "augmentation-person-ontonotes5.json\n",
      "37883 37883\n",
      "2530 2530\n",
      "augmentation-product-ontonotes5.json\n",
      "7040 7040\n",
      "1760 1760\n",
      "augmentation-law-ontonotes5.json\n",
      "12584 12584\n",
      "3161 3161\n",
      "augmentation-language-ontonotes5.json\n",
      "6860 6860\n",
      "0 0\n",
      "augmentation-norp-ontonotes5.json\n",
      "30660 30660\n",
      "5940 5940\n",
      "augmentation-address-ontonotes5.json\n",
      "57502 57502\n",
      "15106 15106\n"
     ]
    }
   ],
   "source": [
    "from glob import glob\n",
    "\n",
    "augmented = glob('augmentation-*-ontonotes5.json')\n",
    "\n",
    "for f in augmented:\n",
    "    print(f)\n",
    "    with open(f) as fopen:\n",
    "        data = json.load(fopen)\n",
    "        \n",
    "    print(len(data.get('train_X', [])), len(data.get('train_Y', [])))\n",
    "    print(len(data.get('test_X', [])), len(data.get('test_Y', [])))\n",
    "    \n",
    "    train_X.extend(data.get('train_X', []))\n",
    "    train_Y.extend(data.get('train_Y', []))\n",
    "    test_X.extend(data.get('test_X', []))\n",
    "    test_Y.extend(data.get('test_Y', []))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from malaya.text.bpe import WordPieceTokenizer\n",
    "\n",
    "tokenizer = WordPieceTokenizer('BERT.wordpiece', do_lower_case = False)\n",
    "# tokenizer.tokenize('halo nama sayacomel')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "def XY(strings):\n",
    "    left_train, right_train = strings[0], strings[1]\n",
    "    X, Y, MASK = [], [], []\n",
    "    for i in tqdm(range(len(left_train))):\n",
    "        left = [d for d in left_train[i]]\n",
    "        right = [d for d in right_train[i]]\n",
    "        bert_tokens = ['[CLS]']\n",
    "        y = ['PAD']\n",
    "        for no, orig_token in enumerate(left):\n",
    "            t = tokenizer.tokenize(orig_token)\n",
    "            bert_tokens.extend(t)\n",
    "            if len(t):\n",
    "                y.append(right[no])\n",
    "            y.extend(['X'] * (len(t) - 1))\n",
    "        bert_tokens.append('[SEP]')\n",
    "        y.append('PAD')\n",
    "        x = tokenizer.convert_tokens_to_ids(bert_tokens)\n",
    "        y = [tag2idx[i] for i in y]\n",
    "        input_mask = [1] * len(y)\n",
    "        if len(x) != len(y):\n",
    "            print(i)\n",
    "        X.append(x)\n",
    "        Y.append(y)\n",
    "        MASK.append(input_mask)\n",
    "    return X, Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 571296/571296 [11:34<00:00, 822.33it/s] \n"
     ]
    }
   ],
   "source": [
    "o = XY([train_X, train_Y])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 78839/78839 [01:40<00:00, 780.89it/s]\n"
     ]
    }
   ],
   "source": [
    "test_o = XY([test_X, test_Y])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(571296, 571296, 78839, 78839)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(o[0]), len(o[1]), len(test_o[0]), len(test_o[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "with open('ontonotes5-fastformer.pkl', 'wb') as fopen:\n",
    "    pickle.dump([o[0], o[1], test_o[0], test_o[1]], fopen)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
